import tensorflow as tf
from tensorflow.keras.applications.vgg16 import preprocess_input
import os
import glob

def _list_class_names(dirs):
    """Return the sorted, de-duplicated names of the subdirectories of *dirs*."""
    return sorted({
        entry
        for root in dirs
        for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    })


def get_dataset(train_dirs, val_dir, batch_size=64, shuffle=5000, pref=2, shuffle_val=False):
    """Build the training and validation ``tf.data`` pipelines.

    Files are expected at ``<dir>/<class_folder>/<name>.JPEG``. The class
    folder name is the label; label ids are the positions of the folder names
    in the sorted list of all subfolders found under ``train_dirs``.

    Args:
        train_dirs: Iterable of directories whose class subfolders together
            form the training set.
        val_dir: Directory holding the validation class subfolders.
        batch_size: Number of examples per batch in both datasets.
        shuffle: Shuffle-buffer size applied to the training file paths.
            A falsy value (``0``/``None``) skips the buffer shuffle.
        pref: Number of batches to prefetch in both datasets.
        shuffle_val: Whether the validation file list is shuffled.

    Returns:
        ``(train_ds, val_ds)``. Both yield ``(image, label_id)`` batches where
        images are resized to 224x224 and passed through the VGG16
        ``preprocess_input``. ``label_id`` is ``-1`` for a folder name that
        does not occur under ``train_dirs``. ``train_ds`` repeats
        indefinitely and skips elements whose processing raises an error;
        ``val_ds`` is a single pass and does not skip errors.
    """
    IMG_SIZE = 224  # For ImageNet compatibility

    # Build the label vocabulary and its lookup table once, eagerly, before
    # any map call. Creating the table inside the mapped function would
    # rebuild it on every trace and leave its lifetime tied to a local
    # variable of the traced function.
    label_names = _list_class_names(train_dirs)
    label_table = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            # Explicit dtypes keep the table well-typed even for an empty
            # class list (tf.constant([]) would otherwise infer float32).
            keys=tf.constant(label_names, dtype=tf.string),
            values=tf.constant(list(range(len(label_names))), dtype=tf.int32),
        ),
        default_value=-1,
    )

    def process_image(file_path):
        # The parent directory name is the class label, e.g. 'n01440764'.
        parts = tf.strings.split(file_path, os.sep)
        label_id = label_table.lookup(parts[-2])
        image = tf.io.read_file(file_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
        image = preprocess_input(image)
        return image, label_id

    # Validation files; order is fixed unless shuffle_val is set.
    val_files = tf.data.Dataset.list_files(
        os.path.join(val_dir, "*/*.JPEG"), shuffle=shuffle_val
    )
    # Use all training subfolders as one dataset.
    train_ds = tf.data.Dataset.list_files(
        [os.path.join(d, "*/*.JPEG") for d in train_dirs],
        shuffle=True
    )

    # Shuffle file paths rather than decoded images, so the shuffle buffer
    # holds short strings instead of `shuffle` full-size float images.
    if shuffle:
        train_ds = train_ds.shuffle(shuffle)
    train_ds = train_ds.map(process_image, num_parallel_calls=4)
    # Drop failing elements (e.g. unreadable JPEGs) before batching, so one
    # bad file costs one example rather than a partially assembled batch.
    train_ds = train_ds.apply(tf.data.experimental.ignore_errors())
    train_ds = train_ds.batch(batch_size).repeat()
    train_ds = train_ds.prefetch(pref)

    # Cap the threads used by the training pipeline.
    options = tf.data.Options()
    options.experimental_threading.max_intra_op_parallelism = 1
    options.experimental_threading.private_threadpool_size = 4
    train_ds = train_ds.with_options(options)

    val_ds = val_files.map(process_image, num_parallel_calls=4)
    val_ds = val_ds.batch(batch_size).prefetch(pref)
    return train_ds, val_ds